import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
from tqdm import tqdm
import wandb


def tmp_draw(logits_plane, save_path, num, range, not_record):
    """Render a logit-std landscape as a 3D surface and persist it.

    Writes ``landscape/logits_landscape_{num}.npy`` (raw values) and
    ``landscape/logits_landscape_{num}.png`` (rendered surface) under
    ``save_path``; the ``landscape`` directory is created when missing.

    Args:
        logits_plane: 2-D NumPy array of surface heights. Its first index
            is drawn along the X axis and its second along the Y axis.
        save_path: Root output directory.
        num: Tag interpolated into the output file names.
        range: Half-width of both axes; ticks span ``[-range, range]``.
            (The name shadows the builtin but is part of the signature.)
        not_record: When truthy, the image is not logged to wandb.
    """
    # Alias the builtin-shadowing parameter so the body reads unambiguously.
    half_width = range
    x = np.linspace(-half_width, half_width, logits_plane.shape[0])
    y = np.linspace(-half_width, half_width, logits_plane.shape[1])
    # 'ij' indexing keeps logits_plane[i, j] at (x[i], y[j]); the default
    # 'xy' indexing would silently swap the two plotted directions.
    X, Y = np.meshgrid(x, y, indexing='ij')

    os.makedirs(os.path.join(save_path, 'landscape'), exist_ok=True)
    npy_path = os.path.join(save_path, f'landscape/logits_landscape_{num}.npy')
    png_path = os.path.join(save_path, f'landscape/logits_landscape_{num}.png')

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(X, Y, logits_plane, cmap='RdBu')

        ax.set_xlabel('Input space 1')
        ax.set_ylabel('Input space 2')
        ax.set_zlabel('Std of logits')

        # Transparent background and unfilled panes.
        ax.patch.set_alpha(0.0)
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False

        np.save(npy_path, logits_plane)
        fig.savefig(png_path)
    finally:
        # plt.clf() would only clear the figure; closing it releases the
        # memory so repeated calls do not accumulate open figures.
        plt.close(fig)

    if not not_record:
        wandb.log({'landscape/logits_landscape': wandb.Image(png_path)})


@torch.no_grad()
def plot_logit_landscape(
        model, loader, save_path,
        type='images', steps=40, landscape_range=0.1, not_record=False):
    """Compute and plot a logit landscape around the data or the weights.

    With ``type='images'`` the first batch of ``loader`` is shifted along
    two random unit-norm directions in input space, and the batch-mean of
    the per-sample standard deviation of the logits is recorded on a
    ``steps x steps`` grid. This branch moves tensors with ``.cuda()`` and
    therefore needs a CUDA device.

    With ``type='parameters'`` the model weights are perturbed along two
    orthogonal random directions instead (see ``compute_logits_plane``).

    The global torch RNG is re-seeded with 42 so the directions are
    reproducible.

    Args:
        model: Callable mapping an input batch to logits.
        loader: Iterable of ``(inputs, labels)`` batches; for ``'images'``
            it must also expose ``loader.dataset[0][0].shape``.
        save_path: Root output directory.
        type: ``'images'`` or ``'parameters'``.
        steps: Grid points per axis; must be at least 2.
        landscape_range: Half-width of the input-space sweep (``'images'``
            branch only).
        not_record: Forwarded to ``tmp_draw`` to skip wandb logging
            (``'images'`` branch only).

    Raises:
        ValueError: If ``type`` is not a supported mode or ``steps < 2``.
    """
    if type not in ('images', 'parameters'):
        raise ValueError(
            f"type must be 'images' or 'parameters', got {type!r}")
    if steps < 2:
        # The grid spacing divides by (steps - 1), and a surface needs at
        # least two points per axis.
        raise ValueError(f"steps must be at least 2, got {steps!r}")

    torch.manual_seed(42)
    os.makedirs(os.path.join(save_path, 'landscape'), exist_ok=True)

    if type == 'images':
        shape = loader.dataset[0][0].shape
        rv1 = torch.randn(shape).cuda()
        rv1 = rv1 / rv1.norm()
        rv2 = torch.randn(shape).cuda()
        rv2 = rv2 / rv2.norm()

        # Evenly spaced offsets covering [-landscape_range, landscape_range].
        offsets = [
            -landscape_range + landscape_range * 2 * (i / (steps - 1))
            for i in range(steps)
        ]

        logits_plane = np.zeros((steps, steps))
        for x, _ in loader:
            x = x.cuda()
            for x_step, rv1_weight in enumerate(tqdm(offsets)):
                for y_step, rv2_weight in enumerate(offsets):
                    shift = rv1 * rv1_weight + rv2 * rv2_weight
                    x_ = x + shift.unsqueeze(0)
                    std_logits = model(x_).std(1)
                    logits_plane[x_step, y_step] += std_logits.mean().item()
            # Only the first batch is used; stop before loading another one.
            break
        tmp_draw(logits_plane, save_path, 'total', landscape_range, not_record=not_record)
    else:
        wrapped_model = ModelInterface(model, loader)
        direction1, direction2 = orthogonal_directions(model)
        logits_plane = compute_logits_plane(
            wrapped_model, direction1, direction2, steps=steps)
        plot_3d_logits_plane(logits_plane, save_path)
        np.save(os.path.join(save_path, 'logit_landscape.npy'), logits_plane)


class ModelInterface:
    """Bundle a model with a weight snapshot and one batch from a loader.

    On construction the current parameter values are cloned into
    ``init_weights`` and the inputs of the first batch yielded by
    ``loader`` are kept as ``x``.
    """

    def __init__(self, model, loader):
        self.model = model
        self.loader = loader
        # Snapshot the weights before anything else touches the model.
        self.init_weights = self.get_weights()
        first_batch = next(iter(self.loader))
        # Batches are (inputs, labels) pairs; only the inputs are kept.
        self.x, _ = first_batch

    def get_weights(self):
        """Return cloned copies of every parameter tensor, in order."""
        snapshot = []
        for param in self.model.parameters():
            snapshot.append(param.data.clone())
        return snapshot

    def set_weights(self, weights):
        """Copy ``weights`` element-wise into the model parameters in place."""
        pairs = zip(self.model.parameters(), weights)
        for param, value in pairs:
            param.data.copy_(value)

    def eval(self, x):
        """Return the log-sum-exp over dim 1 of the model's logits for ``x``."""
        return torch.logsumexp(self.model(x), dim=1)


@torch.no_grad()
def compute_logits_plane(model: "ModelInterface", direction1, direction2, steps=40,
                         landscape_range=0.01):
    """Evaluate the mean log-sum-exp of the logits on a 2-D weight grid.

    For every grid point the weights are set to
    ``init_w + a * d1 + b * d2`` with ``a`` and ``b`` evenly spaced in
    ``[-landscape_range, landscape_range]``, and ``model.eval`` is averaged
    over the stored batch ``model.x``. The model is moved to CUDA when
    available (CPU otherwise), and the initial weights are restored before
    returning, even if evaluation fails.

    Args:
        model: Wrapper exposing ``model``, ``init_weights``, ``x``,
            ``set_weights`` and ``eval``.
        direction1: Per-parameter tensors for the first grid axis.
        direction2: Per-parameter tensors for the second grid axis.
        steps: Grid points per axis. With ``steps == 1`` only the
            unperturbed weights are evaluated.
        landscape_range: Half-width of the sweep along each direction.

    Returns:
        ``(steps, steps)`` NumPy array; entry ``[i, j]`` uses the i-th
        offset along ``direction1`` and the j-th along ``direction2``.

    Raises:
        ValueError: If ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError("steps must be a positive integer")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.model.to(device)
    # Use the selected device rather than a hard-coded .cuda() so the CPU
    # fallback above actually works.
    batch = model.x.to(device)

    if steps == 1:
        # A single grid point has no spacing; evaluate the centre.
        offsets = [0.0]
    else:
        offsets = [
            -landscape_range + 2 * landscape_range * (i / (steps - 1))
            for i in range(steps)
        ]

    logits_plane = np.zeros((steps, steps))
    try:
        for x_step, x_weight in enumerate(tqdm(offsets)):
            for y_step, y_weight in enumerate(offsets):
                new_weights = [
                    init_w + x_weight * d1 + y_weight * d2
                    for init_w, d1, d2 in zip(
                        model.init_weights, direction1, direction2)
                ]
                model.set_weights(new_weights)
                logits_plane[x_step, y_step] = model.eval(batch).mean().item()
    finally:
        # Do not leave the caller's model sitting at the last grid point.
        model.set_weights(model.init_weights)
    return logits_plane


def random_direction(model):
    """Return one random tensor per model parameter.

    Each tensor has the shape, dtype and device of its parameter, so it can
    be added to that parameter directly. Values come from
    ``torch.rand_like``, i.e. uniform on ``[0, 1)``.

    NOTE(review): every component is non-negative; if a zero-mean direction
    is intended (as in ``orthogonal_directions``), ``torch.randn_like`` may
    be what was meant -- confirm.
    """
    # No forced .cuda(): rand_like already follows each parameter's device,
    # which also keeps this usable on CPU-only hosts.
    return [torch.rand_like(param) for param in model.parameters()]
   

def orthogonal_directions(model):
    """Draw two random directions that are orthogonal parameter by parameter.

    ``direction1`` holds one standard-normal tensor per model parameter.
    Each tensor of ``direction2`` is a fresh standard-normal draw with its
    projection onto the matching ``direction1`` tensor removed (a single
    Gram-Schmidt step).

    Returns:
        Tuple ``(direction1, direction2)`` of lists aligned with
        ``model.parameters()``.
    """
    direction1 = [torch.randn_like(param) for param in model.parameters()]
    direction2 = []
    for d1 in direction1:
        d2 = torch.randn_like(d1)
        # reshape (not view) so non-contiguous layouts such as channels_last
        # weights, whose strides randn_like preserves, can be flattened.
        d1_flat = d1.reshape(-1)
        d2_flat = d2.reshape(-1)
        projection = torch.dot(d2_flat, d1_flat) / torch.dot(d1_flat, d1_flat)
        direction2.append(d2 - projection * d1)
    return direction1, direction2


def plot_3d_logits_plane(logits_plane, save_path, not_record=False):
    """Plot a logits plane as a 3D surface and save it as a PNG.

    The image is written to ``save_path/logits_landscape.png`` and, unless
    ``not_record`` is truthy, logged to wandb under
    ``landscape/logits_landscape``.

    Args:
        logits_plane: 2-D NumPy array of surface heights. Its first index
            is drawn along 'Direction 1' and its second along 'Direction 2'.
        save_path: Existing directory receiving the PNG.
        not_record: When truthy, skip the wandb upload.

    Both axes are drawn on a fixed ``[-1, 1]`` scale, regardless of the
    perturbation magnitudes used to produce ``logits_plane``.
    """
    x = np.linspace(-1, 1, logits_plane.shape[0])
    y = np.linspace(-1, 1, logits_plane.shape[1])
    # 'ij' indexing keeps logits_plane[i, j] at (x[i], y[j]); the default
    # 'xy' indexing would swap the two directions on the plot.
    X, Y = np.meshgrid(x, y, indexing='ij')

    png_path = os.path.join(save_path, 'logits_landscape.png')

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(X, Y, logits_plane, cmap='RdBu')

        ax.set_xlabel('Direction 1')
        ax.set_ylabel('Direction 2')
        ax.set_zlabel('Logits')
        ax.set_title('3D Logits Landscape')
        fig.savefig(png_path)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    if not not_record:
        wandb.log({'landscape/logits_landscape': wandb.Image(png_path)})